Add model metadata #135
Conversation
Pull Request Overview
This PR extends model metadata to include multihit and neutral model settings and integrates these changes across tests and core model functions.
- Updates tests to load and use multihit models via load_multihit.
- Extends AbstractBinarySelectionModel and SingleValueBinarySelectionModel with new metadata (including model_type, train_timestamp, neutral_model_name, and multihit_model_name) and adjusts hyperparameter defaults.
- Enhances framework functions (including add_shm_model_outputs_to_pcp_df and DXSMBurrito initialization) to verify model metadata consistency.
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
| File | Description |
| --- | --- |
| tests/test_simulation.py | Uses load_multihit to retrieve the multihit model and adds tolerance to an allclose check; reassigns train_dataset to val_dataset. |
| tests/test_multihit.py | Updates model instantiation to pass model_type and generates multihit_model_name from model weights. |
| tests/test_dnsm.py, test_ddsm.py, test_dasm.py, test_ambiguous.py | Integrates the new model_type parameter and multihit_model into model/dataset creation. |
| netam/pretrained.py | Introduces load_multihit and name_and_multihit_model_match for multihit model handling. |
| netam/models.py | Extends metadata in model constructors and updates the reinitialize_weights, to_weights, and from_weights methods. |
| netam/framework.py | Adds default hyperparameter values for legacy models and filters sequences in add_shm_model_outputs_to_pcp_df. |
| netam/dxsm.py | Implements metadata validation with warnings about model_type and multihit model consistency. |
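The metadata validation in netam/dxsm.py is only summarized here. A minimal sketch of the warn-on-mismatch pattern (the function name and exact keys below are assumptions for illustration, not the actual netam API):

```python
import warnings

def check_model_metadata(hyperparameters, expected_model_type, multihit_model_name):
    # Hypothetical sketch: warn rather than fail when saved metadata
    # disagrees with how the burrito is being constructed.
    saved_type = hyperparameters.get("model_type", "unknown")
    if saved_type != expected_model_type:
        warnings.warn(
            f"model_type mismatch: saved {saved_type!r}, expected {expected_model_type!r}"
        )
    saved_name = hyperparameters.get("multihit_model_name")
    if saved_name != multihit_model_name:
        warnings.warn(
            f"multihit model mismatch: saved {saved_name!r}, got {multihit_model_name!r}"
        )
```

Warning (rather than raising) keeps old crepes loadable while still surfacing inconsistencies.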
Just a few final todos 👍
```diff
@@ -66,13 +63,7 @@ def apply_multihit_correction(
     per_parent_hit_class = parent_specific_hit_classes(parent_codon_idxs)
     corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()
     reshaped_corrections = corrections[per_parent_hit_class]
     unnormalized_corrected_probs = clamp_probability(codon_probs * reshaped_corrections)
```
This is just a refactor: the forward method of the multihit model still sets the parent codon probability, but this allows the model to expose a method that adjusts codon probs without setting the parent codon probability.
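As a sketch of the mechanics visible in the diff above (assuming `per_parent_hit_class` has shape (L, 64), one hit class per candidate codon, and that hit class 0 receives no correction; the final clamping step is omitted):

```python
import torch

def multihit_correction_sketch(log_hit_class_factors, per_parent_hit_class, codon_probs):
    # Hit class 0 (no nucleotide changes) gets a fixed log-factor of 0;
    # classes 1..3 get the learned factors, applied multiplicatively.
    corrections = torch.cat([torch.tensor([0.0]), log_hit_class_factors]).exp()
    return codon_probs * corrections[per_parent_hit_class]
```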
Leaving this here just to document that I did try implementing simulation probabilities with a standalone function:

```python
def codon_probs_of_parent_seq_new(
    selection_crepe, nt_sequence, branch_length, neutral_crepe=None, multihit_model=None
):
    """Calculate the predicted model probabilities of each codon at each site.

    Args:
        nt_sequence: A tuple of two strings, the heavy and light chain nucleotide
            sequences.
        branch_length: The branch length of the tree.

    Returns:
        A tuple of tensors of shape (L, 64) representing the predicted
        probabilities of each codon at each site.
    """
    if neutral_crepe is None:
        raise NotImplementedError("neutral_crepe is required.")
    if isinstance(nt_sequence, str) or len(nt_sequence) != 2:
        raise ValueError(
            "nt_sequence must be a pair of strings, with the first element being "
            "the heavy chain sequence and the second element being the light chain sequence."
        )
    aa_seqs = tuple(translate_sequences_mask_codons(nt_sequence))
    # We must mask any codons containing N's because we need neutral probs to
    # do simulation:
    mask = tuple(codon_mask_tensor_of(chain_nt_seq) for chain_nt_seq in nt_sequence)
    rates, csps = trimmed_shm_outputs_of_parent_pair(neutral_crepe, nt_sequence)
    selection_factors = selection_crepe([aa_seqs])[0]
    if selection_crepe.model.hyperparameters["output_dim"] == 1:
        # Need to upgrade the single selection factor to 20 selection factors,
        # all equal except for the one for the parent sequence, which should be
        # 1 (0 in log space).
        new_selection_factors = []
        for aa_seq, old_selection_factors in zip(aa_seqs, selection_factors):
            if len(aa_seq) == 0:
                new_selection_factors.append(
                    torch.empty(0, 20, dtype=old_selection_factors.dtype)
                )
            else:
                parent_indices = aa_idx_tensor_of_str_ambig(aa_seq)
                new_selection_factors.append(
                    # Selection factors are expected to be in linear space here
                    molevol.lift_to_per_aa_selection_factors(
                        old_selection_factors, parent_indices
                    )
                )
        selection_factors = tuple(new_selection_factors)
    parent_nt_idxs = tuple(
        nt_idx_tensor_of_str(nt_chain_seq.replace("N", "A"))
        for nt_chain_seq in nt_sequence
    )
    codon_probs = []
    for parent_idxs, nt_csps, nt_rates, sel_matrix in zip(
        parent_nt_idxs, csps, rates, selection_factors
    ):
        if len(parent_idxs) > 0:
            nt_mut_probs = 1.0 - torch.exp(-branch_length * nt_rates)
            codon_mutsel, _ = molevol.build_codon_mutsel(
                parent_idxs.reshape(-1, 3),
                nt_mut_probs.reshape(-1, 3),
                nt_csps.reshape(-1, 3, 4),
                sel_matrix,
                multihit_model=multihit_model,
            )
            codon_probs.append(
                molevol.zero_stop_codon_probs(
                    molevol.flatten_codons(clamp_probability(codon_mutsel))
                )
            )
        else:
            codon_probs.append(torch.empty(0, 64, dtype=torch.float32))
    return tuple(codon_probs)
```
Pull Request Overview
This PR introduces additional metadata fields to saved models and updates tests, fixtures, and model constructors accordingly. Key changes include:
- Adding metadata keys (multihit_model_name, neutral_model_name, train_timestamp, model_type) to model initialization and hyperparameters.
- Updating tests and fixtures across multiple files to accommodate the new metadata.
- Enhancing pretrained model loading with a new multihit models dictionary and associated utility functions.
Reviewed Changes
Copilot reviewed 26 out of 26 changed files in this pull request and generated 2 comments.
| File | Description |
| --- | --- |
| tests/* | Update fixtures and test references for new metadata and defaults |
| netam/models.py | Extend model constructors and hyperparameters with metadata |
| netam/pretrained.py | Add the PRETRAINED_MULTIHIT_MODELS dict and load_multihit function |
| netam/molevol.py, netam/hit_class.py | Update processing of mutation probabilities with multihit support |
| netam/framework.py, others | Various adjustments to integrate metadata into the workflow |
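The shape of the new registry and loader in netam/pretrained.py isn't shown in this summary. A minimal sketch of the lookup pattern (the dictionary values below are placeholders for illustration, not the real stored weights):

```python
# Assumed shape: a name -> saved-weights registry, plus a loader that
# treats None as "no multihit correction".
PRETRAINED_MULTIHIT_MODELS = {
    "ThriftyHumV0.2-59-hc-tangshm": {"log_hit_class_factors": [0.1, 0.2, 0.3]},
}

def load_multihit(multihit_model_name):
    # None means no multihit model; unknown names fail loudly.
    if multihit_model_name is None:
        return None
    try:
        return PRETRAINED_MULTIHIT_MODELS[multihit_model_name]
    except KeyError:
        raise ValueError(f"unknown multihit model: {multihit_model_name!r}")
```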
Comments suppressed due to low confidence (1)
netam/models.py:580
- Consider requiring 'model_type' as a mandatory argument instead of defaulting to None and issuing a warning, to enforce explicit model typing and simplify downstream logic.
```python
def __init__(
    self,
    output_dim: int = 1,
    known_token_count: int = MAX_AA_TOKEN_IDX + 1,
    neutral_model_name: str = DEFAULT_NEUTRAL_MODEL,
    multihit_model_name: str = DEFAULT_MULTIHIT_MODEL,
    train_timestamp: str = None,
    model_type: str = None,
):
```
Amazing! Some very minor things here.
netam/common.py (outdated)

```diff
@@ -67,6 +67,10 @@ def clamp_probability(x: Tensor) -> Tensor:
     return torch.clamp(x, min=SMALL_PROB, max=(1.0 - SMALL_PROB))


+def clamp_probability_above(x: Tensor) -> Tensor:
```
Let's make this function name concordant with the function below. They are both above, right?
Changed to clamp_probability_above_only
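To illustrate the distinction between the two clamps (a scalar sketch only; netam's actual functions operate on tensors via torch.clamp, and the SMALL_PROB value here is an assumption):

```python
SMALL_PROB = 1e-6  # assumed value; the real constant lives in netam/common.py

def clamp_probability(p):
    # Two-sided clamp: keep p inside [SMALL_PROB, 1 - SMALL_PROB].
    return min(max(p, SMALL_PROB), 1.0 - SMALL_PROB)

def clamp_probability_above_only(p):
    # One-sided clamp: cap at 1 - SMALL_PROB but leave tiny probabilities
    # alone, so summing many near-zero entries doesn't accumulate fake mass.
    return min(p, 1.0 - SMALL_PROB)
```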
```python
DEFAULT_NEUTRAL_MODEL = "ThriftyHumV0.2-59"
DEFAULT_MULTIHIT_MODEL = None
# # ATTENTION!!! when done with dnsm retrainings, switch back to this:
# DEFAULT_MULTIHIT_MODEL = "ThriftyHumV0.2-59-hc-tangshm"
```
Do you have a calendar event or something to remind us? 😁
Just created one for Friday!
```python
result = self.apply_multihit_correction(
    parent_codon_idxs, uncorrected_codon_probs
)
# clamp only above to avoid summing a bunch of small fake values when
```
Could we pull this out into a function with a name and a slightly clearer docstring? This is a little on the opaque side.
Added more comments instead
netam/models.py (outdated)

```python
):
    """Apply the correction to the uncorrected codon probabilities.

    Unlike `forward` this does
```
line wrapping
```diff
@@ -320,7 +324,7 @@ def build_codon_mutsel(
     codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

     if multihit_model is not None:
-        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+        codon_probs = multihit_model.forward(parent_codon_idxs, codon_probs)
```
I would have thought these were identical.
They are, but I like the style of an explicit named method call for code searching purposes
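For reference, the equivalence the reviewer points out holds because torch.nn.Module.\_\_call\_\_ dispatches to forward (after running any registered hooks). A pure-Python sketch of that dispatch:

```python
class MultihitLike:
    # Minimal stand-in for an nn.Module: __call__ just delegates to forward,
    # which is why model(x) and model.forward(x) produce the same result
    # (modulo hooks, which nn.Module.__call__ also runs).
    def forward(self, x):
        return x * 2

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

model = MultihitLike()
assert model(3) == model.forward(3)
```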
```diff
@@ -383,7 +387,7 @@ def neutral_codon_probs(
     codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

     if multihit_model is not None:
-        codon_probs = multihit_model(parent_codon_idxs, codon_probs)
+        codon_probs = multihit_model.forward(parent_codon_idxs, codon_probs)
```
ditto
tests/test_simulation.py (outdated)

```python
print(flat_log_codon_mutsel[diff_mask])
assert False

# adjusted_codon_probs = molevol.zero_stop_codon_probs(clamp_probability(adjusted_codon_probs.exp()).log())
```
I have not taken a close read of these functions; I trust that they are doing what you want, but perhaps you want to make a quick scan to tidy things up. Is this useful or cruft?
This PR adds the following values to the metadata of saved models:

- `multihit_model_name`: expected to be a key in `netam.pretrained.PRETRAINED_MULTIHIT_MODELS`. Defaults to `netam.models.DEFAULT_MULTIHIT_MODEL`. For crepes saved without this data, defaults to `None`.
- `neutral_model_name`: expected to be a named pretrained neutral model. Defaults to `netam.models.DEFAULT_NEUTRAL_MODEL`. For crepes saved without this data, defaults to `ThriftyHumV0.2-59`.
- `train_timestamp`: a UTC timestamp taken at the time of model initialization, if not provided explicitly (e.g. `2025-05-01T22:05`). For crepes saved without this data, defaults to `old`.
- `model_type`: either `dnsm`, `dasm`, or `ddsm`, which must be provided at the time of model instantiation. For crepes saved without this data, defaults to `unknown`, and will throw warnings.

As hinted at above, I added a dictionary containing pretrained multihit models to `netam.pretrained`. These models can be accessed by name using `netam.pretrained.load_multihit`.

Requires companion PR https://github.com/matsengrp/dnsm-experiments-1/pull/132
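The backfill behavior for old crepes described above can be sketched as follows (the helper names and mechanism are assumptions for illustration; netam's actual implementation lives in the model/framework loading code):

```python
from datetime import datetime, timezone

# Fallbacks applied to crepes saved before this PR's metadata existed.
LEGACY_DEFAULTS = {
    "multihit_model_name": None,
    "neutral_model_name": "ThriftyHumV0.2-59",
    "train_timestamp": "old",
    "model_type": "unknown",
}

def fill_legacy_metadata(hyperparameters):
    # Keys absent from an old saved crepe get the documented fallbacks;
    # keys present in the file win.
    return {**LEGACY_DEFAULTS, **hyperparameters}

def new_train_timestamp():
    # Fresh models record a UTC timestamp at initialization,
    # e.g. "2025-05-01T22:05".
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M")
```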